Keras 训练时不用将数据全部加入内存 您所在的位置:网站首页 keras fit train_on_batch Keras 训练时不用将数据全部加入内存

Keras 训练时不用将数据全部加入内存

2023-07-17 04:48| 来源: 网络整理| 查看: 265

How can I use Keras with datasets that don't fit in memory?

You can do batch training using `model.train_on_batch(X, y)` and `model.test_on_batch(X, y)`. See the models documentation.

Alternatively, you can write a generator that yields batches of training data and use the method `model.fit_generator(data_generator, steps_per_epoch, epochs)`.

You can see batch training in action in our CIFAR10 example.

代码1:图像分类

import codecs
import os

import cv2
import numpy as np
from keras.callbacks import Callback, ModelCheckpoint
from keras.layers import Activation, Conv2D, Dense, Dropout, Flatten, MaxPooling2D
from keras.models import Sequential, load_model
from keras.preprocessing.image import ImageDataGenerator, array_to_img, img_to_array, load_img
from visual_callbacks import AccLossPlotter

# Train a small CNN character classifier on 32x32 grayscale images that are
# streamed from disk in batches by ImageDataGenerator, so the whole dataset
# never has to be loaded into memory at once.
#
# NOTE(review): this script uses the Keras 1.x API (`border_mode`,
# `Conv2D(filters, rows, cols)`, `samples_per_epoch`, `nb_epoch`,
# `nb_val_samples`). When porting to Keras 2 / tf.keras, rename those
# arguments. Note that `steps_per_epoch` / `validation_steps` count *batches*,
# not samples.

# Project-local callback that plots accuracy/loss curves; save_graph=True
# presumably also writes the figure to disk (see visual_callbacks).
plotter = AccLossPlotter(graphs=['acc', 'loss'], save_graph=True)


class LossHistory(Callback):
    """Record the training loss reported at the end of every batch."""

    def on_train_begin(self, logs=None):
        # Reset the per-batch loss log at the start of each fit call.
        self.losses = []

    def on_batch_end(self, batch, logs=None):
        # None instead of a shared mutable `{}` default; treat a missing
        # logs dict as empty rather than crashing.
        self.losses.append((logs or {}).get('loss'))


# Light augmentation for training; rescale maps 8-bit pixel values into [0, 1].
datagen = ImageDataGenerator(
    rotation_range=0,
    width_shift_range=0.1,
    height_shift_range=0.1,
    rescale=1. / 255,
    shear_range=0.1,
    zoom_range=0.1,
    horizontal_flip=False,
    fill_mode='nearest',
)

train_generator = datagen.flow_from_directory(
    os.path.join('chars_rec', 'train'),  # target directory, one sub-directory per class
    target_size=(32, 32),  # all images are resized to 32x32
    batch_size=32,
    shuffle=True,
    class_mode='categorical',  # one-hot labels, matching categorical_crossentropy below
    color_mode='grayscale',
)

# `class_indices` maps class (sub-directory) name -> integer label, so its size
# is the number of classes. Unlike `nb_class`, it exists on both Keras 1 and 2.
class_count = len(train_generator.class_indices)
print(class_count)

# Persist the name -> label mapping so predictions can be decoded later.
# np.save pickles the dict into a 0-d object array and appends '.npy', so the
# file on disk is 'class_indices.txt.npy'. To load it back (numpy >= 1.16.3
# needs allow_pickle=True for object arrays):
#     class_indices = np.load('class_indices.txt.npy', allow_pickle=True).tolist()
#     value_indices = {v: k for k, v in class_indices.items()}
np.save('class_indices.txt', train_generator.class_indices)

# Validation images are only rescaled, never randomly augmented, so the
# validation loss that the checkpoint below monitors is measured on the same
# inputs every epoch.
val_datagen = ImageDataGenerator(rescale=1. / 255)
validation_generator = val_datagen.flow_from_directory(
    # NOTE(review): 'valication' is kept verbatim; presumably this is the
    # on-disk directory name. Confirm before renaming.
    os.path.join('chars_rec', 'valication'),
    target_size=(32, 32),  # all images are resized to 32x32
    batch_size=32,
    class_mode='categorical',  # one-hot labels, matching categorical_crossentropy below
    color_mode='grayscale',
)

# Three conv/pool stages followed by a dense classifier head.
# Assumes channels-last input (rows, cols, channels); TODO confirm the
# backend's image dim ordering matches.
model = Sequential()
model.add(Conv2D(32, 3, 3, input_shape=(32, 32, 1), border_mode='same', activation='relu'))
model.add(MaxPooling2D(pool_size=(2, 2)))
model.add(Conv2D(32, 3, 3, border_mode='same', activation='relu'))
model.add(MaxPooling2D(pool_size=(2, 2)))
model.add(Conv2D(64, 3, 3, border_mode='same', activation='relu'))
model.add(MaxPooling2D(pool_size=(2, 2)))
model.add(Flatten())  # flatten the 3D feature maps into a 1D feature vector
model.add(Dense(128))
model.add(Activation('relu'))
model.add(Dropout(0.5))
model.add(Dense(class_count))
model.add(Activation('softmax'))

model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])

# Save the full model whenever the monitored metric (val_loss by default)
# reaches a new best, so the file always holds the best model seen so far.
checkpointer = ModelCheckpoint(filepath="chars_rec.hdf5", verbose=1, save_best_only=True)
history = LossHistory()

# Resume from a previous run's best checkpoint if one exists. load_model
# restores architecture, weights and compile settings, replacing the fresh
# model built above.
if os.path.exists('chars_rec.hdf5'):
    model = load_model('chars_rec.hdf5')

model.fit_generator(
    train_generator,
    # NOTE(review): 9150 / 1062 are hard-coded; presumably the train and
    # validation set sizes. Confirm they match the directories.
    samples_per_epoch=9150,  # Keras 1: number of *samples* drawn per epoch
    nb_epoch=500,
    validation_data=validation_generator,
    nb_val_samples=1062,  # Keras 1: number of validation *samples* per epoch
    callbacks=[checkpointer, history, plotter],
)

# Also keep the model as it stands after the final epoch (not necessarily the best).
model.save('chars_rec_end.hdf5')

代码2:文本标引

def getXY_gen(batch_size=32): # f = file(".\\SheKeYuan_YinWen0303_train.utf8") # data = f.read()[0:].decode('utf-8') # f.close() f = open(".\\train_470000_0427.utf8",'r',encoding='utf-8') lines = f.readlines() f.close() # print(lines[0:10]) # exit() X=[] Y=[] for i in range(len(lines)):#tqdm(range(len(lines))): line=lines[i].strip() # print(i+1)#,line) # if i>=100000: # break x = [] y = [] y_temp = [] for j,string in enumerate(line.split(' ')): if string.find('=')1 and label_num


【本文地址】

公司简介

联系我们

今日新闻

    推荐新闻

    专题文章
      CopyRight 2018-2019 实验室设备网 版权所有